#ifndef DBSCAN_H_
#define DBSCAN_H_

#pragma once
#include <Eigen/Core>
#include <Eigen/Dense>
#include <algorithm>
#include <deque>
#include <iostream>
#include <vector>

using std::deque;
using std::vector;

// struct Point2 {
//   float x = 0.0f;
//   float y = 0.0f;
// };

// namespace TBD
// {

/**
 * @brief Density-based spatial clustering of applications with noise.
 * Implementation of the DBSCAN algorithm described at
 * https://zh.wikipedia.org/wiki/DBSCAN.
 * The distance between every pair of points is cached in a matrix so each
 * distance is computed only once, at the cost of O(n^2) memory.
 * @tparam T point/object type; a private GetDistance overload must exist for
 *           it (overloads for Eigen::Vector2d and Eigen::Vector2f are provided).
 */
template <class T>
class DBScan {
 private:
  // Minimum neighborhood size (the point itself included) for a point to be
  // treated as a core point; smaller neighborhoods are labeled noise.
  int min_num_ = 0;

  // Neighborhood radius: maximum distance between two neighboring points.
  double eps_ = 0.0;

  // Cluster label of each point: -1 = noise, 0 = not yet visited,
  // >= 1 = cluster index.
  vector<int> _labels;

  // Cached pairwise distances (n x n, symmetric) of the last Run() input.
  Eigen::MatrixXd _distance_mat;

  // TODO: you should define the distance function of your struct/class
  /**
   * @brief Euclidean distance between two points.
   * @param p1 object/point 1
   * @param p2 object/point 2
   * @return double
   */
  double GetDistance(const Eigen::Vector2d &p1, const Eigen::Vector2d &p2) {
    return (p1 - p2).norm();
  }
  float GetDistance(const Eigen::Vector2f &p1, const Eigen::Vector2f &p2) {
    return (p1 - p2).norm();
  }
  // inline double GetDistance(const Point2 &p1, const Point2 &p2) {
  //   return sqrt(std::pow((p1.x - p2.x), 2) + pow((p1.y - p2.y), 2));
  // }

  /**
   * @brief core function of dbscan: labels every point of all_pts
   * @param all_pts, all points/objects
   * @return int, number of clusters found
   */
  int Run(const vector<T> &all_pts);

  /**
   * @brief find neighbors within eps of a point and grow a cluster from it
   * @param pt_idx, point index of all points
   * @param cluster_idx, in/out: label for the new cluster; incremented when
   *        a cluster is generated
   * @return bool, true: cluster is generated, false: point/object is noise
   */
  bool ExpandCluster(int pt_idx, int &cluster_idx);

 public:
  /**
   * @brief Construct a new DBScan object
   * @param eps, neighborhood radius (maximum distance between two neighbors)
   * @param min_num, minimum neighborhood size (point itself included) for a
   *        core point
   */
  DBScan(double eps, int min_num) : min_num_(min_num), eps_(eps) {}
  DBScan() = default;

  /**
   * @brief Cluster the given points.
   * @param all_pts, all points/objects
   * @return vector<vector<T>>, index 0: noise points, 1-n: cluster points
   */
  vector<vector<T>> GetClusters(std::vector<T> all_pts);

  /**
   * @brief Get the labels of the objects from the last clustering
   * @return vector<int>, cluster label of each object, by input index
   *         (-1 means noise)
   */
  vector<int> GetLabels() const { return _labels; }

  /// @brief set the neighborhood radius
  /// @param val new eps (fractional values are kept, not truncated)
  void SetEPS(double val) { eps_ = val; }

  /// @brief set the minimum neighborhood size for a core point
  /// @param val new min number
  void SetMinNum(int val) { min_num_ = val; }
};

/**
 * @brief Label every point of all_pts as noise (-1) or a cluster (>= 1).
 * @param all_pts, all points/objects
 * @return int, number of clusters found
 */
template <typename T>
int DBScan<T>::Run(const vector<T> &all_pts) {
  const int size = static_cast<int>(all_pts.size());
  int cluster_idx = 1;

  // -1: noise, 0: unlabeled, >= 1: cluster index.
  // assign() rather than resize(): resize() keeps labels left over from a
  // previous call, so a reused instance would skip every point, report zero
  // clusters and make GetClusters() index out of range.
  _labels.assign(all_pts.size(), 0);

  // 1. Pairwise distances. The matrix is symmetric with a zero diagonal, so
  //    only the upper triangle is computed and mirrored. Each (i, j) pair is
  //    written by a single i, so the outer loop could be parallelized with
  //    OpenMP if the data is big:
  // #pragma omp parallel for schedule(runtime)
  _distance_mat = Eigen::MatrixXd::Zero(size, size);
  for (int i = 0; i < size; ++i) {
    for (int j = i + 1; j < size; ++j) {
      const double dist = GetDistance(all_pts[i], all_pts[j]);
      _distance_mat(i, j) = dist;
      _distance_mat(j, i) = dist;
    }
  }

  // 2. Grow a cluster (or mark noise) from every point not yet visited.
  for (int i = 0; i < size; ++i) {
    if (_labels[i] == 0) ExpandCluster(i, cluster_idx);
  }

  return cluster_idx - 1;
}

/**
 * @brief Try to grow a new cluster from point pt_idx.
 * A point is a core point when its eps-neighborhood (distance <= eps_, the
 * point itself included) holds at least min_num_ points. The same test is used
 * for the seed and for every point reached during expansion.
 * @param pt_idx, index of the seed point (expected to be unlabeled)
 * @param cluster_idx, in/out: label of the new cluster; incremented on success
 * @return bool, true: a cluster was generated, false: pt_idx is noise
 */
template <typename T>
bool DBScan<T>::ExpandCluster(int pt_idx, int &cluster_idx) {
  const int num_pts = static_cast<int>(_distance_mat.cols());

  // Reused buffer for region queries, allocated once per call.
  vector<int> neighbors;
  neighbors.reserve(static_cast<size_t>(num_pts));

  // Fills `neighbors` with `center` plus every point within eps_ of it.
  auto region_query = [&](int center) {
    neighbors.clear();
    neighbors.emplace_back(center);
    for (int col = 0; col < num_pts; ++col) {
      if (col != center && _distance_mat(center, col) <= eps_) {
        neighbors.emplace_back(col);
      }
    }
  };

  // True when the last queried neighborhood makes its center a core point.
  auto is_core = [&]() {
    return static_cast<int>(neighbors.size()) >= min_num_;
  };

  // 1. region query on the seed
  region_query(pt_idx);

  // 2. not a core point: noise (it may later become a border point of
  //    another cluster)
  if (!is_core()) {
    _labels[pt_idx] = -1;
    return false;
  }

  // 3. label the seed, then attach its neighbors
  _labels[pt_idx] = cluster_idx;
  deque<int> seeds_idx;

  // Unvisited neighbors join the cluster and are queued for expansion.
  // Former noise points join as border points; they are not expanded because
  // they already failed the core test. Points owned by an earlier cluster keep
  // their label.
  auto absorb_neighbors = [&]() {
    for (const int idx : neighbors) {
      if (_labels[idx] == 0) {
        _labels[idx] = cluster_idx;
        seeds_idx.emplace_back(idx);
      } else if (_labels[idx] == -1) {
        _labels[idx] = cluster_idx;
      }
    }
  };
  absorb_neighbors();

  // 4. breadth-first expansion through core points
  while (!seeds_idx.empty()) {
    const int row = seeds_idx.front();
    seeds_idx.pop_front();
    region_query(row);
    if (is_core()) absorb_neighbors();
  }

  ++cluster_idx;
  return true;
}

/**
 * @brief Run DBSCAN on all_pts and group the points by label.
 * @param all_pts, all points/objects
 * @return vector<vector<T>>, index 0: noise points, 1-n: cluster points
 */
template <typename T>
vector<vector<T>> DBScan<T>::GetClusters(std::vector<T> all_pts) {
  const int clusters_num = this->Run(all_pts);

  // Bucket 0 collects noise; buckets 1..clusters_num hold the clusters.
  vector<vector<T>> result(clusters_num + 1);

  for (size_t i = 0; i < _labels.size(); ++i) {
    const int label = _labels[i];
    // Every non-positive label goes to the noise bucket. The previous
    // `label + 1` indexing would have filed an unlabeled point (0) under
    // cluster 1.
    const size_t bucket = label >= 1 ? static_cast<size_t>(label) : 0;
    result[bucket].emplace_back(all_pts[i]);
  }

  return result;
}
// } // namespace TBD

#endif